#!/usr/bin/env python3

import argparse
from collections import OrderedDict, namedtuple
import os
import sys

from utilities_common import constants
from natsort import natsorted
from tabulate import tabulate
from sonic_py_common import multi_asic
from swsscommon.swsscommon import APP_FABRIC_PORT_TABLE_NAME, COUNTERS_TABLE, COUNTERS_FABRIC_PORT_NAME_MAP, COUNTERS_FABRIC_QUEUE_NAME_MAP
import utilities_common.multi_asic as multi_asic_util

# mock the redis for unit test purposes #
# When UTILITIES_UNIT_TESTING == "2": put the repo root and its tests/ directory
# on sys.path and import the mock DB connector (mock_tables.dbconnector).
# When UTILITIES_UNIT_TESTING_TOPOLOGY == "multi_asic": also import the mock
# multi-ASIC helpers and load the mock namespace configuration.
# A missing environment variable raises KeyError, which aborts the rest of this
# setup. In particular, if UTILITIES_UNIT_TESTING is unset, the topology check
# below is never reached.
# NOTE(review): the multi_asic branch uses mock_tables.dbconnector and relies on
# tests/ being on sys.path, i.e. it assumes the "2" branch above also ran —
# confirm the test harness always sets both variables together.
try:
    if os.environ["UTILITIES_UNIT_TESTING"] == "2":
        modules_path = os.path.join(os.path.dirname(__file__), "..")
        tests_path = os.path.join(modules_path, "tests")
        sys.path.insert(0, modules_path)
        sys.path.insert(0, tests_path)
        import mock_tables.dbconnector
    if os.environ["UTILITIES_UNIT_TESTING_TOPOLOGY"] == "multi_asic":
        import mock_tables.mock_multi_asic
        mock_tables.dbconnector.load_namespace_config()
except KeyError:
    pass

# Leading text stripped from a port name to form the PORT column value.
PORT_NAME_PREFIX = 'PORT'
# Key prefix for counter rows in COUNTERS_DB: "<COUNTERS_TABLE>:<table_id>".
COUNTER_TABLE_PREFIX = "{}:".format(COUNTERS_TABLE)
# Key prefix for the per-port rows that get_port_state() reads from STATE_DB.
FABRIC_PORT_STATUS_TABLE_PREFIX = "{}|".format(APP_FABRIC_PORT_TABLE_NAME)
# Field of that row holding the port state shown in the STATE column.
FABRIC_PORT_STATUS_FIELD = "STATUS"
# Placeholder printed when a state or counter value is missing.
STATUS_NA = 'N/A'

class FabricStat(object):
    """
    Base class for fabric counter display.

    Subclasses implement get_cnstat() to read the counters of the current
    namespace through self.db, and cnstat_print() to render them.
    """

    def __init__(self, namespace):
        # Database handle read by get_port_state()/get_counters(); None here.
        # NOTE(review): presumably assigned per namespace by the
        # run_on_multi_asic decorator before collect_stat() runs — confirm.
        self.db = None
        self.namespace = namespace
        self.multi_asic = multi_asic_util.MultiAsic(constants.DISPLAY_ALL, namespace)

    def get_cnstat_dict(self):
        """
        Reset the accumulator, run collect_stat() and return the collected
        counters as an OrderedDict.
        """
        self.cnstat_dict = OrderedDict()
        self.collect_stat()
        return self.cnstat_dict

    @multi_asic_util.run_on_multi_asic
    def collect_stat(self):
        """
        Collect the statistics from all the asics present on the
        device and store in a dict
        """
        self.cnstat_dict.update(self.get_cnstat())

    def get_port_state(self, port_name):
        """
        Return the STATUS field of port_name's fabric port row in STATE_DB,
        or STATUS_NA when the lookup returns None.
        """
        full_table_id = FABRIC_PORT_STATUS_TABLE_PREFIX + port_name
        oper_state = self.db.get(self.db.STATE_DB, full_table_id, FABRIC_PORT_STATUS_FIELD)
        return STATUS_NA if oper_state is None else oper_state

    def get_counters(self, counter_bucket_dict, table_id):
        """
        Read a set of counters for one COUNTERS_DB row.

        :param counter_bucket_dict: maps output position -> counter field name.
        :param table_id: row identifier appended to COUNTER_TABLE_PREFIX.
        :return: list of strings, one per position: the counter value as a
                 decimal string, or STATUS_NA when the field is missing.
                 Positions not present in counter_bucket_dict stay "0".
        """
        # The row key does not depend on the counter, so build it once.
        full_table_id = COUNTER_TABLE_PREFIX + table_id
        fields = ["0"] * len(counter_bucket_dict)
        for pos, counter_name in counter_bucket_dict.items():
            counter_data = self.db.get(self.db.COUNTERS_DB, full_table_id, counter_name)
            # int() normalizes the stored value (and rejects non-numeric data).
            fields[pos] = STATUS_NA if counter_data is None else str(int(counter_data))
        return fields

    def get_cnstat(self):
        """
        Get the counters info from database.

        :raises NotImplementedError: subclasses must override this method.
        """
        # Not an assert: asserts are stripped under `python -O`.
        raise NotImplementedError('Need to override this method')

    def cnstat_print(self, cnstat_dict, errors_only=False):
        """
        Print the counter stat.

        :raises NotImplementedError: subclasses must override this method.
        """
        raise NotImplementedError('Need to override this method')

# Per-port counter snapshot; field order matches port_counter_bucket_list.
PortStat = namedtuple("PortStat", [
    "in_cell", "in_octet", "out_cell", "out_octet",
    "crc", "fec_correctable", "fec_uncorrectable", "symbol_err",
])

# SAI port counters read for every fabric port, in PortStat field order.
port_counter_bucket_list = [
    'SAI_PORT_STAT_IF_IN_FABRIC_DATA_UNITS',
    'SAI_PORT_STAT_IF_IN_OCTETS',
    'SAI_PORT_STAT_IF_OUT_FABRIC_DATA_UNITS',
    'SAI_PORT_STAT_IF_OUT_OCTETS',
    'SAI_PORT_STAT_IF_IN_ERRORS',
    'SAI_PORT_STAT_IF_IN_FEC_CORRECTABLE_FRAMES',
    'SAI_PORT_STAT_IF_IN_FEC_NOT_CORRECTABLE_FRAMES',
    'SAI_PORT_STAT_IF_IN_FEC_SYMBOL_ERRORS',
]
# Position -> counter name.
port_counter_bucket_dict = dict(enumerate(port_counter_bucket_list))

# Column headers for the port table. The errors-only view keeps the
# ASIC/PORT/STATE key columns plus the four error counters; the full view
# inserts the traffic counters between them.
portstat_header_errors_only = ['ASIC', 'PORT', 'STATE',
                               'CRC', 'FEC_CORRECTABLE', 'FEC_UNCORRECTABLE', 'SYMBOL_ERR']
portstat_header_all = (portstat_header_errors_only[:3]
                       + ['IN_CELL', 'IN_OCTET', 'OUT_CELL', 'OUT_OCTET']
                       + portstat_header_errors_only[3:])

class FabricPortStat(FabricStat):
    """
    Fabric port counters: one PortStat row per entry of
    COUNTERS_FABRIC_PORT_NAME_MAP.
    """

    def get_cnstat(self):
        """
        Return an OrderedDict of port name -> PortStat for the current
        self.db, in natural-sort order of port name. The result is empty when
        the port name map lookup returns None.
        """
        name_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_FABRIC_PORT_NAME_MAP)
        result = OrderedDict()
        if name_map is not None:
            for port_name in natsorted(name_map):
                values = self.get_counters(port_counter_bucket_dict, name_map[port_name])
                result[port_name] = PortStat._make(values)
        return result

    def cnstat_print(self, cnstat_dict, errors_only=False):
        """
        Print a table of the port counters in cnstat_dict.

        :param cnstat_dict: port name -> PortStat, as built by get_cnstat().
        :param errors_only: when true, show only the CRC/FEC/symbol-error
                            counter columns after ASIC/PORT/STATE.
        """
        if not cnstat_dict:
            print("Counters %s empty" % self.namespace)
            return

        # The column layout does not change per row, so pick it once.
        if errors_only:
            header = portstat_header_errors_only
            counter_fields = ('crc', 'fec_correctable', 'fec_uncorrectable', 'symbol_err')
        else:
            header = portstat_header_all
            counter_fields = ('in_cell', 'in_octet', 'out_cell', 'out_octet',
                              'crc', 'fec_correctable', 'fec_uncorrectable', 'symbol_err')

        asic = multi_asic.get_asic_id_from_name(self.namespace)
        rows = []
        for port_name, stat in cnstat_dict.items():
            key_columns = (asic, port_name[len(PORT_NAME_PREFIX):], self.get_port_state(port_name))
            rows.append(key_columns + tuple(getattr(stat, field) for field in counter_fields))

        print(tabulate(rows, header, tablefmt='simple', stralign='right'))
        print()

# Per-queue counter snapshot; field order matches queue_counter_bucket_list.
QueueStat = namedtuple("QueueStat", ["curlevel", "watermarklevel", "curbyte"])

# SAI queue counters read for every fabric queue, in QueueStat field order.
queue_counter_bucket_list = [
    'SAI_QUEUE_STAT_CURR_OCCUPANCY_LEVEL',
    'SAI_QUEUE_STAT_WATERMARK_LEVEL',
    'SAI_QUEUE_STAT_CURR_OCCUPANCY_BYTES',
]
# Position -> counter name.
queue_counter_bucket_dict = dict(enumerate(queue_counter_bucket_list))

# Column headers for the queue table.
queuestat_header = [
    'ASIC', 'PORT', 'STATE', 'QUEUE_ID',
    'CURRENT_BYTE', 'CURRENT_LEVEL', 'WATERMARK_LEVEL',
]

class FabricQueueStat(FabricStat):
    """
    Fabric queue counters: one QueueStat row per entry of
    COUNTERS_FABRIC_QUEUE_NAME_MAP.
    """

    def get_cnstat(self):
        """
        Return an OrderedDict of "<port>:<queue>" name -> QueueStat for the
        current self.db, in natural-sort order. The result is empty when the
        queue name map lookup returns None.
        """
        queue_map = self.db.get_all(self.db.COUNTERS_DB, COUNTERS_FABRIC_QUEUE_NAME_MAP)
        result = OrderedDict()
        if queue_map is None:
            return result
        for queue_name in natsorted(queue_map):
            values = self.get_counters(queue_counter_bucket_dict, queue_map[queue_name])
            result[queue_name] = QueueStat._make(values)
        return result

    def cnstat_print(self, cnstat_dict, errors_only=False):
        """
        Print a table of the queue counters in cnstat_dict.

        Each key is split on ':' into a port name and a queue id.
        errors_only is accepted for interface compatibility and not used.
        """
        if not cnstat_dict:
            print("Counters %s empty" % self.namespace)
            return

        asic = multi_asic.get_asic_id_from_name(self.namespace)

        def make_row(queue_key, stat):
            # One table row: key columns followed by the three queue counters.
            port_name, queue_id = queue_key.split(':')
            return (asic, port_name[len(PORT_NAME_PREFIX):], self.get_port_state(port_name),
                    queue_id, stat.curbyte, stat.curlevel, stat.watermarklevel)

        rows = [make_row(queue_key, stat) for queue_key, stat in cnstat_dict.items()]
        print(tabulate(rows, queuestat_header, tablefmt='simple', stralign='right'))
        print()

def main():
    """
    Parse the command line and print fabric port or queue counters.

    With -n/--namespace only that namespace is printed; otherwise every
    namespace returned by MultiAsic().get_ns_list_based_on_options() is
    printed in turn. -e/--errors only affects the port view.
    """
    # The examples below must use only the options defined further down:
    # argparse rejects unknown flags with "unrecognized arguments".
    parser = argparse.ArgumentParser(description='Display the fabric port state and counters',
                                     formatter_class=argparse.RawTextHelpFormatter,
                                     epilog="""
Examples:
    fabricstat
    fabricstat --namespace asic0
    fabricstat -n asic0 -e
    fabricstat -q
    fabricstat -q -n asic0
""")

    parser.add_argument('-q', '--queue', action='store_true', help='Display fabric queue stat, otherwise port stat')
    parser.add_argument('-n', '--namespace', default=None, help='Display fabric ports counters for specific namespace')
    parser.add_argument('-e', '--errors', action='store_true', help='Display errors')

    args = parser.parse_args()
    queue = args.queue
    namespace = args.namespace
    errors_only = args.errors

    def nsStat(ns, errors_only):
        """Collect and print the selected counters for namespace ns."""
        stat = FabricQueueStat(ns) if queue else FabricPortStat(ns)
        cnstat_dict = stat.get_cnstat_dict()
        stat.cnstat_print(cnstat_dict, errors_only)

    if namespace is None:
        # All asics or all fabric asics.
        # Not named `multi_asic`, so the file-level sonic_py_common.multi_asic
        # module is not shadowed inside this function.
        multi_asic_obj = multi_asic_util.MultiAsic()
        for ns in multi_asic_obj.get_ns_list_based_on_options():
            nsStat(ns, errors_only)
    else:
        # Asic with namespace
        nsStat(namespace, errors_only)


if __name__ == "__main__":
    main()
